{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# Expose only GPU 0 to CUDA; this is set before torch is imported below.\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
    "\n",
    "import torch\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig\n",
    "import transformers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from snapkv.monkeypatch.monkeypatch import replace_llama, replace_mistral, replace_mixtral"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Apply the SnapKV monkeypatch for Mistral before the model is instantiated in a later cell.\n",
    "# NOTE(review): presumably this swaps in SnapKV's attention implementation - confirm in snapkv.monkeypatch.\n",
    "replace_mistral()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastchat.model import load_model, get_conversation_template"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load Mistral-7B-Instruct in bfloat16 and let device_map=\"auto\" place the weights.\n",
    "# attn_implementation=\"flash_attention_2\" replaces the deprecated\n",
    "# use_flash_attention_2=True flag and selects the same FlashAttention-2 path.\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    \"mistralai/Mistral-7B-Instruct-v0.2\",\n",
    "    torch_dtype=torch.bfloat16,\n",
    "    low_cpu_mem_usage=True,\n",
    "    device_map=\"auto\",\n",
    "    attn_implementation=\"flash_attention_2\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The tokenizer is loaded from the same checkpoint name as the model.\n",
    "tokenizer = AutoTokenizer.from_pretrained(\n",
    "    \"mistralai/Mistral-7B-Instruct-v0.2\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the long context document that the question will be asked about.\n",
    "# The encoding is passed explicitly because open() otherwise falls back to a\n",
    "# platform-dependent default.\n",
    "# NOTE(review): assumes snapkv.txt is UTF-8 encoded - confirm.\n",
    "with open('snapkv.txt', 'r', encoding='utf-8') as f:\n",
    "    content = f.read().strip()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The leading newline separates the question from the document text it is appended to.\n",
    "question = '\\n What is the repository of SnapKV?'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build a single-turn conversation: the document plus the question goes in as the\n",
    "# first role's message, followed by a message of None for the second role.\n",
    "# NOTE(review): presumably the None message leaves that turn open for the model to complete - confirm in fastchat.\n",
    "conv = get_conversation_template(\"mistralai/Mistral-7B-Instruct-v0.2\")\n",
    "conv.messages = []  # start from an empty message history\n",
    "prompt_role = conv.roles[0]\n",
    "reply_role = conv.roles[1]\n",
    "conv.append_message(prompt_role, content + question)\n",
    "conv.append_message(reply_role, None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Render the conversation into the single prompt string that gets tokenized next.\n",
    "prompt = conv.get_prompt()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tokenize the full prompt and return the token ids as a PyTorch tensor.\n",
    "input_ids = tokenizer.encode(\n",
    "    prompt,\n",
    "    return_tensors=\"pt\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Size of dimension 1 of the prompt tensor; used later to slice the prompt off\n",
    "# the generated sequence.\n",
    "input_ids_len = input_ids.shape[1]\n",
    "print(input_ids_len)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Greedy decoding (do_sample=False) of up to 200 new tokens.\n",
    "# The prompt is moved to the model's own device instead of a hard-coded .cuda(),\n",
    "# because device_map=\"auto\" decides where the weights are placed.\n",
    "outputs = model.generate(\n",
    "    input_ids.to(model.device),\n",
    "    max_new_tokens=200,\n",
    "    do_sample=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Skip the first input_ids_len positions (the prompt) so only the generated\n",
    "# continuation is decoded and printed.\n",
    "generated_ids = outputs[0][input_ids_len:]\n",
    "answer = tokenizer.decode(generated_ids, skip_special_tokens=True)\n",
    "print(answer)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "code_attn",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
